{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 值迭代算法\n",
    "作者：stzhao\n",
    "GitHub: https://github.com/zhaoshitian"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 一、定义环境\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import sys,os\n",
    "curr_path = os.path.abspath('')\n",
    "parent_path = os.path.dirname(curr_path)\n",
    "sys.path.append(parent_path)\n",
    "from envs.simple_grid import DrunkenWalkEnv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "metadata": {},
   "outputs": [],
   "source": [
    "def all_seed(env, seed=1):\n",
    "    \"\"\"Fix every random source used in this notebook so runs are reproducible.\n",
    "\n",
    "    Args:\n",
    "        env: environment object; its ``seed`` method is called with ``seed``.\n",
    "        seed: seed value shared by the env, numpy and the ``random`` module.\n",
    "    \"\"\"\n",
    "    import os\n",
    "    import random\n",
    "\n",
    "    import numpy as np\n",
    "\n",
    "    env.seed(seed)\n",
    "    for seed_fn in (np.random.seed, random.seed):\n",
    "        seed_fn(seed)\n",
    "    # Setting PYTHONHASHSEED at runtime does not change the hash randomisation of\n",
    "    # the already-running interpreter; it only takes effect for child processes.\n",
    "    os.environ['PYTHONHASHSEED'] = str(seed)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the 'theAlley' map from envs.simple_grid and fix all random seeds.\n",
    "env = DrunkenWalkEnv(map_name=\"theAlley\")\n",
    "all_seed(env, seed=1)  # seed 1 is used for the env, numpy and random"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 二、价值迭代算法\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "def value_iteration(env, theta=0.005, discount_factor=0.9, max_iterations=100):\n",
    "    \"\"\"Compute an action-value table Q with synchronous value iteration.\n",
    "\n",
    "    Args:\n",
    "        env: environment exposing ``nS`` (number of states), ``nA`` (number of\n",
    "            actions) and ``P``, where ``P[s][a]`` is an iterable of\n",
    "            ``(prob, next_state, reward, done)`` tuples.\n",
    "        theta: stop once the largest change of any Q entry in one sweep is\n",
    "            below this threshold.\n",
    "        discount_factor: discount applied to the value of the next state.\n",
    "        max_iterations: hard cap on the number of sweeps, so the loop exits\n",
    "            even if the values never converge (a warning is issued then).\n",
    "\n",
    "    Returns:\n",
    "        np.ndarray of shape (env.nS, env.nA) holding the Q values.\n",
    "    \"\"\"\n",
    "    import warnings\n",
    "\n",
    "    # Convergence is only guaranteed when every P[s][a] is a probability\n",
    "    # distribution (sums to 1); flag a malformed model before iterating.\n",
    "    for state in range(env.nS):\n",
    "        for a in range(env.nA):\n",
    "            total_prob = sum(prob for prob, _, _, _ in env.P[state][a])\n",
    "            if not np.isclose(total_prob, 1.0):\n",
    "                warnings.warn(f'P[{state}][{a}] probabilities sum to {total_prob:.4f}, '\n",
    "                              'not 1; value iteration may diverge')\n",
    "\n",
    "    Q = np.zeros((env.nS, env.nA))  # initialise the Q table\n",
    "    delta = float('inf')\n",
    "    for _ in range(max_iterations):\n",
    "        V = Q.max(axis=1)  # greedy state values from the previous sweep\n",
    "        Q_new = np.zeros((env.nS, env.nA))\n",
    "        for state in range(env.nS):\n",
    "            for a in range(env.nA):\n",
    "                for prob, next_state, reward, done in env.P[state][a]:\n",
    "                    # The episode ends on a `done` transition, so no future\n",
    "                    # value may be bootstrapped beyond it.\n",
    "                    future = 0.0 if done else V[next_state]\n",
    "                    Q_new[state, a] += prob * (reward + discount_factor * future)\n",
    "        delta = np.max(np.abs(Q_new - Q))\n",
    "        Q = Q_new\n",
    "        if delta < theta:\n",
    "            break\n",
    "    else:\n",
    "        warnings.warn(f'value iteration did not converge within {max_iterations} '\n",
    "                      f'iterations (last delta={delta:.3g})')\n",
    "    return Q"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[[2.25015697e+22 2.53142659e+22 4.50031394e+22 2.53142659e+22]\n",
      " [2.81269621e+22 5.41444021e+22 1.01257064e+23 5.41444021e+22]\n",
      " [6.32856648e+22 1.21824905e+23 2.27828393e+23 1.21824905e+23]\n",
      " [1.42392746e+23 2.74106036e+23 5.12613885e+23 2.74106036e+23]\n",
      " [3.20383678e+23 5.76690620e+23 1.15338124e+24 5.76690620e+23]\n",
      " [7.20863276e+23 1.38766181e+24 2.59510779e+24 1.38766181e+24]\n",
      " [1.62194237e+24 3.12223906e+24 5.83899253e+24 3.12223906e+24]\n",
      " [3.64937033e+24 7.02503789e+24 1.31377332e+25 7.02503789e+24]\n",
      " [8.21108325e+24 1.47799498e+25 2.95598997e+25 1.47799498e+25]\n",
      " [1.84749373e+25 3.55642543e+25 6.65097743e+25 3.55642543e+25]\n",
      " [4.15686089e+25 8.00195722e+25 1.49646992e+26 8.00195722e+25]\n",
      " [9.35293701e+25 1.80044037e+26 3.36705732e+26 1.80044037e+26]\n",
      " [5.89235032e+26 7.36543790e+26 7.57587898e+26 7.36543790e+26]]\n"
     ]
    }
   ],
   "source": [
    "# Run value iteration and print the Q table (one row per state, one column per action).\n",
    "Q = value_iteration(env, theta=0.005, discount_factor=0.9)\n",
    "print(Q)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]\n"
     ]
    }
   ],
   "source": [
    "# Greedy policy: in every state pick the action with the largest Q value\n",
    "# (np.argmax returns the first such action on ties). tolist() yields plain\n",
    "# Python ints and avoids calling int() on a non-scalar array, which NumPy\n",
    "# deprecated in 1.25.\n",
    "policy = np.argmax(Q, axis=1).tolist()\n",
    "print(policy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 三、测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_episode = 1000  # default number of evaluation episodes\n",
    "\n",
    "\n",
    "def test(env, policy, num_episodes=None, goal_state=12, max_steps=10_000):\n",
    "    \"\"\"Roll out ``policy`` in ``env`` and report how often it reaches the goal.\n",
    "\n",
    "    Args:\n",
    "        env: environment whose ``reset()`` returns the start state and whose\n",
    "            ``step(action)`` returns ``(next_state, reward, done, info)``.\n",
    "        policy: indexable mapping from state index to the action to take.\n",
    "        num_episodes: episodes to run; ``None`` uses the global ``num_episode``.\n",
    "        goal_state: an episode succeeds if it ends in this state (12 is the\n",
    "            goal of the map used in this notebook).\n",
    "        max_steps: per-episode step cap so a policy that never reaches a\n",
    "            terminal state cannot hang the loop; such episodes count as\n",
    "            failures unless they stopped on ``goal_state``.\n",
    "\n",
    "    Returns:\n",
    "        The success rate in [0, 1], which is also printed.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if the number of episodes is not positive.\n",
    "    \"\"\"\n",
    "    if num_episodes is None:\n",
    "        num_episodes = num_episode\n",
    "    if num_episodes <= 0:\n",
    "        raise ValueError('num_episodes must be positive')\n",
    "\n",
    "    successes = 0\n",
    "    for _episode in range(num_episodes):\n",
    "        state = env.reset()  # start a new episode\n",
    "        for _step in range(max_steps):\n",
    "            action = policy[state]\n",
    "            state, _reward, done, _info = env.step(action)\n",
    "            if done:\n",
    "                break\n",
    "        if state == goal_state:\n",
    "            successes += 1\n",
    "    acc_suc = successes / num_episodes\n",
    "    print(\"测试的成功率是：\", acc_suc)\n",
    "    return acc_suc"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试的成功率是： 0.64\n"
     ]
    }
   ],
   "source": [
    "# Evaluate the greedy policy over many episodes; the success rate is printed.\n",
    "test(env=env, policy=policy);"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.6 ('RL')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "88a829278351aa402b7d6303191a511008218041c5cfdb889d81328a3ea60fbc"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
